import torch
import torch.nn as nn
import networkx as nx
import torch.nn.functional as F
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from resnet import *
from model_generator import get_mobilenetv2, get_mobilenetv3, get_shufflenet
from model_generator import get_resnet101, get_resnet152, get_resnet18, get_resnet34, get_resnet50
from model_generator import get_vgg11, get_vgg13, get_vgg16, get_vgg19
import timm
import pandas as pd


# def generate_model_paths():
#     base_path = './trained_models'
#     models = [
#         'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
#         'vgg11', 'vgg13', 'vgg16', 'vgg19',
#         'shufflenet', 'mobilenetv2', 'mobilenetv3'
#     ]
    
#     model_paths = {
#         model: f'{base_path}/{model}_cifar10.pth' for model in models
#     }
    
#     return model_paths

# # 生成模型路径字典
# pt = generate_model_paths()

# =============================注意，一次要改3个位置==================================
# model = get_vgg19()
# model.load_state_dict(torch.load(pt['vgg19']))
# model = timm.create_model('tinynet_a.in1k', num_classes = 10)
# G = nx.read_graphml('/data/liuruiheng/neural-graph/graphCons/model_graph/tinynet_a.in1k_new_mlp_method_graph_layer_level_dig.graphml')


# CSV file whose `model_name` column lists the models to process.
input_csv_file = '/data/liuruiheng/TransformerLearning/mismatching_models_20m.csv'
df = pd.read_csv(input_csv_file)

# Names of the models to process; the first 194 entries are skipped
# (presumably already handled by an earlier run -- confirm before changing).
model_names = df['model_name'].tolist()[194:]


def findPairs(idx, str1, str2, match, s_name):
    """Locate neighbouring layers whose node counts match a mismatched pair.

    ``match`` maps keys such as ``'<layer>_in'`` / ``'<layer>_out'`` to lists
    of graph node names.  Given two keys whose lists differ in length, this
    searches ``s_name`` (the ordered layer names) around position ``idx`` for
    alternative partners with matching list lengths.

    Args:
        idx: Pivot position in ``s_name``; later positions are scanned for the
            first result, earlier positions for the second.
        str1: Key in ``match`` whose list length the forward search must equal.
        str2: Key in ``match`` whose list length the backward search must equal.
        match: Mapping from node-group key to a list of node names.
        s_name: Ordered sequence of layer names.

    Returns:
        A tuple ``(idx1, idx2)`` where ``idx1`` is the first position after
        ``idx`` whose ``'<layer>_in'`` list has ``len(match[str1])`` entries,
        and ``idx2`` is the closest position before ``idx`` whose
        ``'<layer>_out'`` list has ``len(match[str2])`` entries.  Either value
        is ``-1`` when no such position exists.

    Raises:
        KeyError: If ``str1`` or ``str2`` is not a key of ``match``.
    """
    len1 = len(match[str1])
    len2 = len(match[str2])

    idx1, idx2 = -1, -1

    # Forward search: first later layer whose input group has len1 nodes.
    for i in range(idx + 1, len(s_name)):
        key = f'{s_name[i]}_in'
        if 'pool' in key:
            continue
        # Layers without an `_in` entry (e.g. pooling layers whose name does
        # not contain "pool") cannot be partners; skip them instead of
        # failing with KeyError.
        nodes = match.get(key)
        if nodes is not None and len(nodes) == len1:
            idx1 = i
            break

    # Backward search: closest earlier layer whose output group has len2 nodes.
    for i in range(idx - 1, -1, -1):
        key = f'{s_name[i]}_out'
        if 'pool' in key:
            continue
        nodes = match.get(key)
        if nodes is not None and len(nodes) == len2:
            idx2 = i
            break

    return idx1, idx2



# For every selected model: load its layer-level graph, connect the node
# groups of consecutive layers, attach the normalization-statistics nodes,
# and save the resulting whole-model graph.
for model_name in model_names:
    # NOTE: remember to adjust the save path at the end of the loop as well.
    print(f'{model_name} is processing')
    pt = f'/data/liuruiheng/neural-graph/graphCons/model_graph/{model_name}_new_mlp_method_graph_layer_level_dig.graphml'
    G = nx.read_graphml(pt)
    model = timm.create_model(model_name, num_classes=10)

    # ---- Partition the graph nodes by kind (naming is model-specific) ----
    g_node_all = list(G.nodes)
    g_node_bn = [node for node in g_node_all if 'bn' in node]
    g_node_shortcut = [node for node in g_node_all if 'downsample' in node and 'bn' not in node]
    g_node_normal = list(set(g_node_all) - set(g_node_bn) - set(g_node_shortcut))

    # ---- Collect, per layer, the sorted graph nodes that belong to it ----
    # Keys are '<module>_in' / '<module>_out' for linear and conv layers,
    # '<module>[_bn]_mean' / '<module>[_bn]_var' for normalization layers and
    # '<module>_pool' for pooling layers.  A node belongs to a key when its
    # name starts with that key.
    matches = {}
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            candidates = g_node_shortcut if 'downsample' in name else g_node_normal
            for suffix in ('in', 'out'):
                key = f'{name}_{suffix}'
                matches[key] = sorted(node for node in candidates if node.startswith(key))
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm)):
            # Normalization layers on a downsample path carry an extra `_bn` tag.
            prefix = f'{name}_bn' if 'downsample' in name else name
            for stat in ('mean', 'var'):
                key = f'{prefix}_{stat}'
                matches[key] = sorted(node for node in g_node_bn if node.startswith(key))
        if isinstance(module, (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d, nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d)):
            key = f'{name}_pool'
            matches[key] = sorted(node for node in g_node_normal if node.startswith(key))

    # ---- Connect consecutive layers (needs per-model customization) ----
    # Ordered names of the main-path layers.  Pooling layers are recognized by
    # module type rather than by a substring of their name, so names such as
    # 'maxpool' or 'stem.pool' are handled and map to their '<name>_pool' key.
    sort_name = []
    poolnodes = []
    for name, module in model.named_modules():
        if 'downsample' in name:
            continue
        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.MaxPool2d)):
            sort_name.append(name)
            if isinstance(module, nn.MaxPool2d):
                poolnodes.append(name)
    print(poolnodes)
    pool_names = set(poolnodes)

    for i in range(len(sort_name) - 1):
        src_name, dst_name = sort_name[i], sort_name[i + 1]
        string1 = f'{src_name}_pool' if src_name in pool_names else f'{src_name}_out'
        string2 = f'{dst_name}_pool' if dst_name in pool_names else f'{dst_name}_in'
        src_nodes = matches[string1]
        dst_nodes = matches[string2]

        # Build the edges for this pair from scratch so that edges computed
        # for an earlier pair are never re-added.
        edges = []
        if len(src_nodes) == len(dst_nodes):
            # Same width: connect the node groups one-to-one.
            edges.extend(zip(src_nodes, dst_nodes))
        elif 'pool' not in string1 and 'pool' not in string2:
            # Width mismatch between two non-pool layers: connect each side to
            # the nearest layer with a compatible width instead.
            str1_id, str2_id = findPairs(idx=i, str1=string1, str2=string2, match=matches, s_name=sort_name)
            if str1_id != -1:
                edges.extend(zip(src_nodes, matches[f'{sort_name[str1_id]}_in']))
            if str2_id != -1:
                edges.extend(zip(matches[f'{sort_name[str2_id]}_out'], dst_nodes))
        else:
            # Width mismatch involving a pooling layer: connect every source
            # node to every destination node.
            edges.extend((x, y) for x in src_nodes for y in dst_nodes)
        G.add_edges_from(edges)
    print()

    # ---- Residual part: shortcut wiring (disabled; customize per model) ----
    # short_name = []
    # for name, module in model.named_modules():
    #     if 'downsample' in name and isinstance(module, (nn.Linear,nn.Conv1d, nn.Conv2d, nn.Conv3d)):
    #         short_name.append(name)
    #
    # linked_nodes = ['layer1.0.conv1_in','layer1.1.conv1_in','layer2.0.conv1_in','layer2.1.conv1_in','layer3.0.conv1_in','layer3.1.conv1_in','layer4.0.conv1_in','layer4.1.conv1_in']
    #
    # for i in range(len(short_name)):
    #     string1 = f'{short_name[i]}_in'
    #     string2 = f'{short_name[i]}_out'
    #     string3 = linked_nodes[i*2]
    #     string4 = linked_nodes[i*2+1]
    #     shn = [(matches[string1][j], matches[string3][j])for j in range(len(matches[string1]))]
    #     shm = [(matches[string2][j], matches[string4][j])for j in range(len(matches[string2]))]
    #     G.add_edges_from(shn)
    #     G.add_edges_from(shm)

    # ---- Normalization part: attach mean/var nodes to layer outputs ----
    abn = []  # key prefixes of normalization layers, in module order
    bbn = []  # names of linear/conv layers, in module order
    for name, module in model.named_modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm)):
            abn.append(f'{name}_bn' if 'downsample' in name else name)
        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            bbn.append(name)

    # NOTE(review): the i-th normalization layer is paired with the i-th
    # linear/conv layer purely by position; this presumably assumes the model
    # alternates them one-to-one -- confirm for each architecture.
    for norm_prefix, layer_name in zip(abn, bbn):
        out_nodes = matches[f'{layer_name}_out']
        for stat in ('mean', 'var'):
            stat_nodes = matches[f'{norm_prefix}_{stat}']
            G.add_edges_from((k, v) for k in stat_nodes for v in out_nodes)

    # ---- Report and save ----
    print(f"Graph has {G.number_of_nodes()} nodes and {G.number_of_edges()} edges.")
    nx.write_graphml(G, f"./model_graph/{model_name}_new_mlp_method_graph_whole_dig.graphml")
# print("Nodes of the graph:")
# for node in G.nodes:
#     print(type(node))
#     print(node)
# print("\nEdges of the graph:")
# for edge in G.edges(data=True):
#     print(edge)
